% Generate training/validation data for ML current inversion: load
% COMSOL-exported current-density maps (text files), compute their magnetic
% fields with a Fourier-space forward model, and save binned current/field
% pairs as .mat batches.
clear all

% Output directory for the saved .mat batches, and input directory holding
% the raw <index>.txt current-density files.
save_dirc = "\\10.229.62.137\data\Experiment folders\ML Inverse Problem\ML Current Inversion Paper (2024)\training and validation data\type2_50um\";
datafolder =  "\\10.229.62.137\data\Experiment folders\ML Inverse Problem\ML Current Inversion Paper (2024)\training and validation data\type2 raw\";

% permeability of free space (SI)
u0= 4*pi*(10^-7);

% Optional random background gradients: the strengths are passed to
% Random_Gradient (defined elsewhere) for Bx/By and Bz respectively, and are
% only applied when add_gradients is true.
max_grad_strength_xy=5e-5;
max_grad_strength_z=0;
add_gradients=false; 
% add_noise is not referenced anywhere else in this script.
add_noise=false;
%% Set up current sources

% Side length of the device region (2000 um, in metres).
device_x_dimension = 2000*10^-6;    

%relative standard deviation of current
% (not referenced anywhere else in this script)
J_std=0.1;

%resolution: pixels across the device region
res=512;

% pixel pitch of the device region
dx= device_x_dimension/res;

%padding: extra pixels added on each side of the device region
padding=44;

% number of pixels of image (device + padding on both sides)
number_x_pixels = res+2*padding;

% total image side length including padding
im_x_dim=dx*number_x_pixels;

%standoff: mean and standard deviation of the random sensor stand-off;
% each sample draws standoff + randn*standoff_std
standoff = 50*10^-6;
standoff_std=10*10^-6;

%depth of current source (multiplies the current density in the field model)
d=14*10^-6;

%current
tot_current=0.05;

%current density 
% Only used by commented-out synthetic-current experiments.
% NOTE(review): 76e-6 looks like a channel width in metres - confirm.
base_current=tot_current/(d*76e-6);

%optical diffraction: Gaussian PSF width parameter (the PSF uses sigma/2)
sigma=2*10^-6;

% x and y sample coordinates of the (padded) image, centred on 0.
% NOTE(review): linspace gives a spacing of im_x_dim/(number_x_pixels-1),
% slightly larger than dx = device_x_dimension/res - confirm intended.
x_vector = linspace(-im_x_dim/2,im_x_dim/2,number_x_pixels);
y_vector = linspace(-im_x_dim/2,im_x_dim/2,number_x_pixels);

% make meshgrids
% x_grid/y_grid are not referenced elsewhere in this script. XX/YY are
% normalised pixel coordinates in (0,1], passed to Random_Gradient.
[x_grid,y_grid] = meshgrid(x_vector,y_vector);
[XX,YY]=meshgrid((1:number_x_pixels)/number_x_pixels,(1:number_x_pixels)/number_x_pixels);

% kx vector (cycles per metre; converted to rad/m below)
% The colon range is cut to its negative values and mirrored, so the grid is
% exactly symmetric. For the current even number_x_pixels (600) this gives
% half-integer multiples of 1/im_x_dim with no k = 0 sample, which keeps the
% division by k_perp_2D in the Bz formula finite.
% NOTE(review): these values do not coincide with the frequency grid that
% fftshift(fft2(.)) uses (integer multiples of 1/(N*delta_x), zero at
% index N/2+1) - confirm this offset is intended.
delta_x = x_vector(end)-x_vector(end-1);
delta_kx = 1/im_x_dim;
kx_max = 0.5/delta_x;
kx_vector = (-kx_max:delta_kx:kx_max);
kx_vector = kx_vector(kx_vector<0);             % keep negatives only: the raw colon range did not give matching +/- values
kx_vector_positives = -kx_vector;
kx_vector = sort([kx_vector,kx_vector_positives]);

% ky vector (same construction as kx)
delta_y = y_vector(end)-y_vector(end-1);
delta_ky = 1/im_x_dim;
ky_max = 0.5/delta_y;
ky_vector = (-ky_max:delta_ky:ky_max);
ky_vector = ky_vector(ky_vector<0);             % keep negatives only: the raw colon range did not give matching +/- values
ky_vector_positives = -ky_vector;
ky_vector = sort([ky_vector,ky_vector_positives]);

% convert to angular wavenumbers (rad/m)
kx_vector = kx_vector*2*pi;
ky_vector = ky_vector*2*pi;

% 2-D wavenumber grids and |k| used by the forward model
[kx_2D,ky_2D] = meshgrid(kx_vector,ky_vector);
k_perp_2D = sqrt(kx_2D.^2+ky_2D.^2);

% Samples per batch and number of batches.
M=512;
num_batches=156;

% Per-sample metadata for a batch (re-initialised for every batch inside the
% loop). A and Complex are saved as Class/Numwire but are never filled by
% this script, so they stay zero; J_sign is not used.
% The binned output arrays (JDat_256/BDat_256) are created inside the loop;
% no 1024x1024 arrays are needed (they would take ~20 GB and were unused).
z=zeros(1,M);
A=zeros(1,M);
Complex=zeros(1,M);
J_sign=zeros(1,M);

% Malformed raw files are replaced by file number end_indx-skip_count,
% counting down from end_indx.
skip_count=0;
end_indx=80487;

%% Forward problem (generate magnetic field) Fourier calculations - slice by slice
% For every batch: load M raw current maps, randomly flip/rotate each one,
% compute (Bx, By, Bz) at a random stand-off in Fourier space, blur with a
% Gaussian PSF, crop the central res x res device window, bin it, and save
% the batch as COMSOL_256_out_<nb>.mat / COMSOL_256_in_<nb>.mat.

% The optical PSF does not depend on the sample, so build it once.
sf=sigma/2;
PSF=exp(-.25*sf^2*(kx_2D.^2+ky_2D.^2));

% Central device window: drop exactly `padding` pixels on every side so the
% crop is res pixels wide (padding:end-padding would keep res+1 pixels and
% shift the window by one).
crop=padding+1:number_x_pixels-padding;

% Binning factor applied to the cropped maps (BinImage is defined elsewhere;
% presumably 512 -> 256 px).
bin_factor=2;

% Every raw file stores a (number_x_pixels+1)^2 grid (361201 rows); any
% other shape is treated as malformed and replaced.
expected_rows=(number_x_pixels+1)^2;
dataappendix=".txt";

for nb=1:num_batches
    % Release the previous batch's large arrays before building the next one
    % (JDat_1024/BDat_1024 only exist if older setup code allocated them).
    clear JDat_1024 BDat_1024 JDat_256 BDat_256 jx jy bx by bz data

    tic
    fprintf('Batch %d of %d\n',nb,num_batches);

    z=zeros(1,M);
    A=zeros(1,M);
    Complex=zeros(1,M);
    J_sign=zeros(1,M);

    for k=1:M
        % Load data: file names below 8406 are zero-padded to 4 digits, the
        % rest to 5.
        sample_idx=(nb-1)*M+k;
        if sample_idx<8406
            dataindex=num2str(sample_idx,'%04.f');
        else
            dataindex=num2str(sample_idx,'%05.f');
        end
        completedatapath=datafolder+dataindex+dataappendix;
        data=importdata(completedatapath);
        datasize=size(data);

        % Malformed file: substitute files counting down from end_indx.
        % NOTE(review): replacements start to overlap the regular index
        % range once skip_count exceeds end_indx-num_batches*M.
        while datasize(2)<4 || datasize(1)~=expected_rows
            dataindex=num2str(end_indx-skip_count);
            completedatapath=datafolder+dataindex+dataappendix;
            data=importdata(completedatapath);
            skip_count=skip_count+1;
            datasize=size(data);
        end

        % Clean up NaN values only after the final file is chosen, so
        % replacement files are cleaned too.
        data(isnan(data))=0;

        % Columns 3 and 4 are Jx and Jy (order following saving settings in
        % COMSOLstudy.m); drop the first row and column of the grid.
        Jx=reshape(data(:,3),[number_x_pixels+1,number_x_pixels+1])';
        Jy=reshape(data(:,4),[number_x_pixels+1,number_x_pixels+1])';
        Jx=Jx(2:end,2:end);
        Jy=Jy(2:end,2:end);

        % Random stand-off distance for this sample.
        z(k)=standoff+randn*standoff_std;

        % Random augmentations: sign flip, 180-degree rotation with sign
        % flip, and 90-degree rotation of the vector field.
        if randi(2)==2
            Jx=-Jx;
            Jy=-Jy;
        end
        if randi(2)==2
            Jx=-rot90(Jx,2);
            Jy=-rot90(Jy,2);
        end
        if randi(2)==2
            temp=Jx;
            Jx=-rot90(Jy);
            Jy=rot90(temp);
        end

        % Field at height z(k), computed in Fourier space; d (depth of the
        % current source) multiplies the current density.
        jx=fftshift(fft2(Jx));
        jy=fftshift(fft2(Jy));
        exp_factor=exp(-k_perp_2D*z(k));
        bx=d*u0/2*(exp_factor.*jy);
        by=d*u0/2*(exp_factor.*(-jx));
        bz=d*u0/2*1i*(exp_factor.*(-ky_2D.*jx+kx_2D.*jy)./k_perp_2D);

        % Optical diffraction blur, then back to real space.
        Bx=real(ifft2(ifftshift(bx.*PSF)));
        By=real(ifft2(ifftshift(by.*PSF)));
        Bz=real(ifft2(ifftshift(bz.*PSF)));

        % Optional random background gradients.
        if add_gradients
            Bx=Bx+Random_Gradient(max_grad_strength_xy,XX,YY);
            By=By+Random_Gradient(max_grad_strength_xy,XX,YY);
            Bz=Bz+Random_Gradient(max_grad_strength_z,XX,YY);
        end

        % Crop the device window and bin.
        jx_bin=BinImage(Jx(crop,crop),bin_factor);
        jy_bin=BinImage(Jy(crop,crop),bin_factor);
        bx_bin=BinImage(Bx(crop,crop),bin_factor);
        by_bin=BinImage(By(crop,crop),bin_factor);
        bz_bin=BinImage(Bz(crop,crop),bin_factor);

        % Allocate the batch arrays once, sized from BinImage's actual output.
        if k==1
            JDat_256=zeros([M,size(jx_bin),2]);
            BDat_256=zeros([M,size(jx_bin),3]);
        end
        JDat_256(k,:,:,1)=jx_bin;
        JDat_256(k,:,:,2)=jy_bin;
        BDat_256(k,:,:,1)=bx_bin;
        BDat_256(k,:,:,2)=by_bin;
        BDat_256(k,:,:,3)=bz_bin;
    end

    % Package the batch: currents are the network outputs, fields (plus
    % per-sample metadata) are the inputs.
    Dat_out_256.Jx=squeeze(JDat_256(:,:,:,1));
    Dat_out_256.Jy=squeeze(JDat_256(:,:,:,2));
    Dat_in_256.Bx=squeeze(BDat_256(:,:,:,1));
    Dat_in_256.By=squeeze(BDat_256(:,:,:,2));
    Dat_in_256.Bz=squeeze(BDat_256(:,:,:,3));
    Dat_in_256.z=z;
    Dat_in_256.Class=A;
    Dat_in_256.Numwire=Complex;

    %% Saving the Data
    filename_out256=save_dirc+"COMSOL_256_out_"+num2str(nb,'%03.f')+".mat";
    save(filename_out256,'Dat_out_256','-v7.3')

    filename_in256=save_dirc+"COMSOL_256_in_"+num2str(nb,'%03.f')+".mat";
    save(filename_in256,'Dat_in_256','-v7.3')

    toc
end

%% Quick-look plot of one sample from the last batch
% Shows the current (Jx, Jy, |J|) and field (Bx, By, Bz) maps for sample k of
% the most recent batch. The loop above fills JDat_256/BDat_256 (the 64-px
% arrays this plot used to read are never created), so those are plotted.
clf
colormap default
k=9;

jx_k=squeeze(JDat_256(k,:,:,1));
jy_k=squeeze(JDat_256(k,:,:,2));

% {image, title} per subplot, in subplot(2,3,p) order.
panels={ ...
    jx_k,                        'Jx'; ...
    jy_k,                        'Jy'; ...
    sqrt(jx_k.^2+jy_k.^2),       'J mag'; ...
    squeeze(BDat_256(k,:,:,1)),  'Bx'; ...
    squeeze(BDat_256(k,:,:,2)),  'By'; ...
    squeeze(BDat_256(k,:,:,3)),  'Bz'};

for p=1:size(panels,1)
    subplot(2,3,p)
    imagesc(panels{p,1})
    title(panels{p,2})
    axis xy equal tight
    colorbar
end

